"""Evaluation Functions """

import os
import os.path as osp
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import cv2
import gym
import numpy as np
from tqdm import tqdm

from diffgro.utils.utils import print_r, print_y, print_b
from diffgro.utils import make_dir, save_image, save_video, write_annotation


def evaluate(
    policy,
    env: gym.Env,
    domain_name: str,
    task_name: str,
    n_episodes: int = 10,
    deterministic: bool = True,
    video: bool = False,
    save_path: str = None,
    context: Tuple[Any] = None,
) -> Union[List[Any], Tuple[List[Any], List[Any]]]:
    """Roll out ``policy`` in ``env`` for several episodes and report statistics.

    Per episode, the per-step ``speed``, ``force``, ``force_axis`` and
    ``energy`` values from the env info dict are averaged. Damping agreement
    (env ``"damping"`` equal to policy ``"guided"``) and the damping rewards
    are averaged when those keys are present. A summary is printed and, when
    ``save_path`` is given, appended to ``<save_path>/evaluation.txt``.

    Args:
        policy: Object whose ``predict(obs, deterministic=...)`` returns
            ``(action, _, info_dict)``. If it has ``reset()``, that is called
            at the start of every episode.
        env: Environment whose ``step`` returns ``(obs, reward, done, info)``.
            ``info`` must contain ``"speed"``, ``"force"``, ``"force_axis"``,
            ``"energy"`` and ``"success"``.
        domain_name: Domain label used in the report.
        task_name: Task label used in the report and in saved media names.
        n_episodes: Number of episodes to run.
        deterministic: Forwarded to ``policy.predict``.
        video: If True, render every non-terminal step and save a video and
            images per episode under ``save_path`` (which must then be set).
        save_path: Output directory. ``None`` skips writing ``evaluation.txt``.
        context: Optional context spec passed to ``evaluate_context`` to score
            each episode.

    Returns:
        The list of per-episode success values. When ``context`` is given,
        returns a tuple of that list and the context rewards of the
        successful episodes.
    """
    tot_success, tot_length = [], []
    tot_speed, tot_force, tot_force_axis, tot_energy, tot_damp = [], [], [], [], []
    tot_context_rew, tot_damp_rew, tot_damp_rew_com = [], [], []

    pbar = tqdm(total=n_episodes)
    for episode in range(n_episodes):
        obs, done, step = env.reset(), False, 0
        speed, force, force_axis, energy, actions = [], [], [], [], []
        damp, damp_rew, damp_rew_com = [], [], []
        if getattr(policy, "reset", None) is not None:
            # Clear any per-episode state the policy keeps.
            policy.reset()

        frames = []
        while not done:
            action, _, p_info = policy.predict(obs, deterministic=deterministic)
            action = np.array(action.copy())
            obs, reward, done, e_info = env.step(action)
            step += 1
            speed.append(e_info["speed"])
            force.append(e_info["force"])
            force_axis.append(e_info["force_axis"])
            energy.append(e_info["energy"])
            actions.append(action[:3])  # only the first three action components
            if "damping_rew" in e_info:
                damp_rew.append(e_info["damping_rew"])
                damp_rew_com.append(e_info["damping_rew_com"])

            if "damping" in e_info and "guided" in p_info:
                # 1.0 when the policy's "guided" flag matches the env's damping.
                damp.append(1.0 if e_info["damping"] == p_info["guided"] else 0.0)

            if done:
                # The terminal step is not rendered.
                break

            if video:
                frame = env.render()
                if "h" in p_info:
                    frame = write_annotation(frame, str(p_info["h"]))
                frames.append(frame)

        tot_success.append(e_info["success"])
        tot_length.append(step)
        tot_speed.append(np.mean(speed))
        tot_force_axis.append(np.mean(force_axis, axis=0))
        tot_force.append(np.mean(force))
        tot_energy.append(np.mean(energy))
        tot_damp.append(np.mean(damp))
        tot_damp_rew.append(np.mean(damp_rew))
        tot_damp_rew_com.append(np.mean(damp_rew_com))
        if context is not None:
            rew = evaluate_context(
                context, tot_force[-1], tot_force_axis[-1], tot_energy[-1], actions
            )
            tot_context_rew.append(rew)
            print_y(f"Context Reward: {rew}, Success: {e_info['success']}")
        else:
            tot_context_rew.append(0.0)

        pbar.update(1)
        pbar.set_description(f"Episodes: {episode + 1}/{n_episodes}")

        if video:
            print_b(f"Saving video at {save_path}")
            video_folder = osp.join(save_path, "video")
            make_dir(video_folder)
            save_video(osp.join(video_folder, f"{task_name}_{episode}.mp4"), frames)

            image_folder = osp.join(save_path, "image", f"{task_name}")
            make_dir(image_folder)
            save_image(image_folder, frames)
    pbar.close()

    def _mean_std(values):
        # Mean and std over episodes (axis 0 keeps per-axis arrays intact).
        return np.mean(values, axis=0), np.std(values, axis=0)

    avg_success, std_success = (v * 100 for v in _mean_std(tot_success))
    avg_length, std_length = _mean_std(tot_length)
    avg_damp, std_damp = _mean_std(tot_damp)
    avg_damp_rew, std_damp_rew = _mean_std(tot_damp_rew)
    avg_damp_rew_com, std_damp_rew_com = _mean_std(tot_damp_rew_com)
    avg_context_rew, std_context_rew = _mean_std(tot_context_rew)

    # The "Sucess ..." statistics only use episodes that succeeded.
    succeeded = [i for i, s in enumerate(tot_success) if s]
    success_length = [tot_length[i] for i in succeeded]
    success_force = [tot_force[i] for i in succeeded]
    success_speed = [tot_speed[i] for i in succeeded]
    success_energy = [tot_energy[i] for i in succeeded]
    # Reported as "Action", but these are the mean per-axis forces (force_axis).
    success_action = [tot_force_axis[i] for i in succeeded]
    success_context_rew = [tot_context_rew[i] for i in succeeded]

    avg_success_length, std_success_length = _mean_std(success_length)
    avg_success_force, std_success_force = _mean_std(success_force)
    avg_success_speed, std_success_speed = _mean_std(success_speed)
    avg_success_energy, std_success_energy = _mean_std(success_energy)
    avg_success_action, std_success_action = _mean_std(success_action)
    avg_success_context_rew, std_success_context_rew = _mean_std(success_context_rew)

    # The report is built once and used for both the console and the file.
    # "+\\-" is the literal text "+\-" (an escaped backslash, not an invalid escape).
    head = [
        f"\tTotal Length: {avg_length} +\\- {std_length}",
        f"\tTotal Success Rate : {avg_success} +\\- {std_success}",
        f"\tSucess Length: {avg_success_length} +\\- {std_success_length}",
        f"\tSucess Force: {avg_success_force:.3f} +\\- {std_success_force:.3f}",
        f"\tSucess Speed: {avg_success_speed:.3f} +\\- {std_success_speed:.3f}",
        f"\tSucess Energy: {avg_success_energy:.3f} +\\- {std_success_energy:.3f}",
    ]
    # Array-valued statistics from here on use these (global) print options.
    np.set_printoptions(precision=3, suppress=True)
    action_stats = f"{avg_success_action} +\\- {std_success_action}"
    tail = [
        f"\tTotal Damp: {avg_damp:.3f} +\\- {std_damp:.3f}",
        f"\tTotal Damp Rew: {avg_damp_rew:.3f} +\\- {std_damp_rew:.3f}",
        f"\tTotal Damp Com: {avg_damp_rew_com:.3f} +\\- {std_damp_rew_com:.3f}",
    ]
    if context is not None:
        tail += [
            f"\tContext Reward: {avg_context_rew:.3f} +\\- {std_context_rew:.3f}",
            f"\tContext Success Reward: {avg_success_context_rew:.3f} +\\- {std_success_context_rew:.3f}",
        ]

    print("")
    print_b("=" * 13 + " Performance Evaluation " + "=" * 13)
    print_r(f"\t{domain_name}.{task_name}")
    for line in head + [f"\tSucess Action: {action_stats}"] + tail:
        print(line)
    print_b("=" * 50)

    # Only persist the report when an output directory was given.
    if save_path is not None:
        make_dir(save_path)
        with open(os.path.join(save_path, "evaluation.txt"), "a") as f:
            f.write("=" * 13 + " Performance Evaluation" + "=" * 13 + "\n")
            f.write(f"\t{domain_name}.{task_name}\n")
            for line in head + [f"\tSuccess Action: {action_stats}"] + tail:
                f.write(line + "\n")
            f.write("=" * 50 + "\n")

    if context is not None:
        return tot_success, success_context_rew
    return tot_success


def evaluate_complex(
    policy,
    env: gym.Env,
    domain_name: str,
    task_name: str,
    n_episodes: int = 10,
    deterministic: bool = True,
    video: bool = False,
    save_path: str = None,
    context: List[Tuple[Any]] = None,
) -> Union[List[Any], Tuple[List[Any], List[Tuple[Any, Any]]]]:
    """Roll out ``policy`` on a four-sub-goal task and report per-sub-goal stats.

    Each episode is split into segments at every step where
    ``info["sub_goal_success"]`` is truthy. Each closed segment keeps its
    mean ``force``, its mean ``force_axis`` and its actions. Sub-goals not
    reached by episode end are padded with ``None``. A summary is printed
    and, when ``save_path`` is given, appended to
    ``<save_path>/evaluation.txt``.

    Args:
        policy: Object whose ``predict(obs, deterministic=...)`` returns
            ``(action, _, info_dict)``. If it has ``reset()``, that is called
            at the start of every episode.
        env: Environment whose ``step`` returns ``(obs, reward, done, info)``.
            ``info`` must contain ``"force"``, ``"force_axis"``,
            ``"sub_goal_success"`` and ``"success"``.
        domain_name: Domain label used in the report.
        task_name: Task label used in the report and in saved media names.
        n_episodes: Number of episodes to run.
        deterministic: Forwarded to ``policy.predict``.
        video: If True, render every non-terminal step and save a video and
            images per episode under ``save_path`` (which must then be set).
        save_path: Output directory. ``None`` skips writing ``evaluation.txt``.
        context: Optional per-sub-goal context specs. Entry ``i`` is scored
            against segment ``i`` with ``evaluate_context``.

    Returns:
        The list of per-episode success values. When ``context`` is given,
        returns a tuple of that list and a per-sub-goal list of
        ``(mean, std)`` context rewards.
    """
    n_sub_goals = 4  # padding and the per-task report assume four sub-goals

    tot_success, tot_length = [], []
    tot_force, tot_force_axis = [], []
    tot_context_rew, tot_damp, tot_damp_rew = [], [], []

    pbar = tqdm(total=n_episodes)
    for episode in range(n_episodes):
        obs, done, step = env.reset(), False, 0
        # One entry per reached sub-goal in this episode (padded with None at the end).
        force, force_axis, actions = [], [], []
        damp, damp_rew = [], []
        # Per-step buffers for the sub-goal currently in progress.
        _force, _force_axis, _actions = [], [], []
        if getattr(policy, "reset", None) is not None:
            # Clear any per-episode state the policy keeps.
            policy.reset()

        frames = []
        while not done:
            action, _, p_info = policy.predict(obs, deterministic=deterministic)
            action = np.array(action.copy())
            obs, reward, done, e_info = env.step(action)
            step += 1
            _force.append(e_info["force"])
            _force_axis.append(e_info["force_axis"])
            _actions.append(action[:3])  # only the first three action components
            if "damping_rew" in e_info:
                damp_rew.append(e_info["damping_rew"])

            if "damping" in e_info and "guided" in p_info:
                # 1.0 when the policy's "guided" flag matches the env's damping.
                damp.append(1.0 if e_info["damping"] == p_info["guided"] else 0.0)

            if e_info["sub_goal_success"]:
                # Close the current segment and start a new one.
                force.append(np.mean(_force))
                force_axis.append(np.mean(_force_axis, axis=0))
                actions.append(_actions)
                _force, _force_axis, _actions = [], [], []

            if done:
                # Pad unreached sub-goals with None so index i means sub-goal i.
                for _ in range(n_sub_goals - len(force)):
                    force.append(None)
                    force_axis.append(None)
                    actions.append(None)
                break

            if video:
                frame = env.render()
                if "h" in p_info:
                    frame = write_annotation(frame, str(p_info["h"]))
                frames.append(frame)

        tot_success.append(e_info["success"])
        tot_length.append(step)
        tot_force.append(force)
        tot_force_axis.append(force_axis)
        tot_damp.append(np.mean(damp))
        tot_damp_rew.append(np.mean(damp_rew))

        if context is not None:
            rew = []
            for i, (act, ctx) in enumerate(zip(actions, context)):
                if act is None:
                    # Sub-goal i was never reached in this episode.
                    rew.append(None)
                else:
                    # Score sub-goal i on its own segment (all of its actions).
                    rew.append(
                        evaluate_context(ctx, force[i], force_axis[i], None, act)
                    )
            tot_context_rew.append(rew)
            np.set_printoptions(precision=3, suppress=True)
            print_y(f"Context Reward: {rew}, Success: {e_info['success']}")
        else:
            # Without a context, mark the first round(success * n_sub_goals)
            # sub-goals as completed (1.0) and the rest as None. One entry is
            # always appended so this list stays aligned with tot_success.
            n_done = int(round(float(e_info["success"]) * n_sub_goals))
            n_done = min(max(n_done, 0), n_sub_goals)
            tot_context_rew.append([1.0] * n_done + [None] * (n_sub_goals - n_done))

        pbar.update(1)
        pbar.set_description(f"Episodes: {episode + 1}/{n_episodes}")

        if video:
            print_b(f"Saving video at {save_path}")
            video_folder = osp.join(save_path, "video")
            make_dir(video_folder)
            save_video(osp.join(video_folder, f"{task_name}_{episode}.mp4"), frames)

            image_folder = osp.join(save_path, "image", f"{task_name}")
            make_dir(image_folder)
            save_image(image_folder, frames)
    pbar.close()

    def _mean_std(values):
        # Mean and std over episodes (axis 0 keeps per-axis arrays intact).
        return np.mean(values, axis=0), np.std(values, axis=0)

    avg_success, std_success = (v * 100 for v in _mean_std(tot_success))
    avg_length, std_length = _mean_std(tot_length)
    avg_damp, std_damp = _mean_std(tot_damp)
    avg_damp_rew, std_damp_rew = _mean_std(tot_damp_rew)

    success_length = []
    success_force = [[] for _ in range(n_sub_goals)]
    success_force_axis = [[] for _ in range(n_sub_goals)]
    success_context_rew = [[] for _ in range(n_sub_goals)]
    for success, length, force, force_axis, context_rew in zip(
        tot_success, tot_length, tot_force, tot_force_axis, tot_context_rew
    ):
        # Per-sub-goal stats use every episode in which that sub-goal counts
        # as completed (non-None reward), not only fully successful episodes.
        for i, ctx_rew in enumerate(context_rew):
            if ctx_rew is None:
                continue
            if force[i] is not None:  # guard: np.mean cannot average None
                success_force[i].append(force[i])
                success_force_axis[i].append(force_axis[i])
            success_context_rew[i].append(ctx_rew)

        if success:
            success_length.append(length)

    avg_success_length, std_success_length = _mean_std(success_length)
    force_stats = [_mean_std(v) for v in success_force]
    force_axis_stats = [_mean_std(v) for v in success_force_axis]
    context_rew_stats = [_mean_std(v) for v in success_context_rew]

    # Average over sub-goals that have any data. The reported "std" is the
    # mean of the per-sub-goal stds, not a pooled std.
    valid_ctx = [stat for stat in context_rew_stats if not np.isnan(stat[0])]
    avg_success_context_rew = np.mean([stat[0] for stat in valid_ctx], axis=0)
    std_success_context_rew = np.mean([stat[1] for stat in valid_ctx], axis=0)

    # Array-valued statistics use these (global) print options.
    np.set_printoptions(precision=3, suppress=True)
    # The report is built once and used for both the console and the file.
    # "+\\-" is the literal text "+\-" (an escaped backslash, not an invalid escape).
    report = [
        f"\tTotal Length: {avg_length:.1f} +\\- {std_length:.1f}",
        f"\tTotal Success Rate : {avg_success:.1f} +\\- {std_success:.1f}",
        f"\tSucess Length: {avg_success_length:.1f} +\\- {std_success_length:.1f}",
    ]
    report += [
        f"\tSuccess Force at Task {i}: {mean:.3f} +\\- {std:.3f}"
        for i, (mean, std) in enumerate(force_stats, start=1)
    ]
    # Reported as "Action", but these are the mean per-axis forces (force_axis).
    report += [
        f"\tSuccess Action at Task {i}: {mean} +\\- {std}"
        for i, (mean, std) in enumerate(force_axis_stats, start=1)
    ]
    report += [
        f"\tTotal Damp: {avg_damp:.3f} +\\- {std_damp:.3f}",
        f"\tTotal Damp Rew: {avg_damp_rew:.3f} +\\- {std_damp_rew:.3f}",
    ]
    if context is not None:
        report += [
            f"\tContext Success Reward at Task {i}: {mean:.3f} +\\- {std:.3f}"
            for i, (mean, std) in enumerate(context_rew_stats, start=1)
        ]
        report.append(
            f"\tAverage Context Success Reward: {avg_success_context_rew:.3f} +\\- {std_success_context_rew:.3f}"
        )

    print("")
    print_b("=" * 13 + " Performance Evaluation " + "=" * 13)
    print_r(f"\t{domain_name}.{task_name}")
    for line in report:
        print(line)
    print_b("=" * 50)

    # Only persist the report when an output directory was given.
    if save_path is not None:
        make_dir(save_path)
        with open(os.path.join(save_path, "evaluation.txt"), "a") as f:
            f.write("=" * 13 + " Performance Evaluation" + "=" * 13 + "\n")
            f.write(f"\t{domain_name}.{task_name}\n")
            for line in report:
                f.write(line + "\n")
            f.write("=" * 50 + "\n")

    if context is not None:
        return tot_success, context_rew_stats
    return tot_success


def evaluate_context(context, avg_force, avg_force_act, avg_energy, actions):
    """Score how well measured forces satisfy a context constraint.

    Args:
        context: Sequence whose items ``[1:4]`` are ``(description,
            context_type, context_target)``. Item 0 is not read here.
        avg_force: Mean scalar force, used by the non-axis constraint types.
        avg_force_act: Mean per-axis force. Index 0 is read for "x-axis ..."
            types and index 1 for "y-axis ..." types.
        avg_energy: Not used; kept for interface compatibility.
        actions: Not used; kept for interface compatibility.

    Returns:
        ``1.0`` when the constraint holds, otherwise ``1.0 - 10 * violation``
        (unbounded below). For ``"speed above and below"``,
        ``context_target`` is ``(lower, upper)`` and the two one-sided
        scores are averaged.

    Raises:
        NotImplementedError: If ``context_type`` is not a supported type.
    """
    description, context_type, context_target = context[1:4]
    print_y(
        f"[evaluate context fn] Evaluating on Context {description} with type {context_type} and target {context_target}"
    )

    def upper_bound_score(value, limit):
        # Penalise only the amount by which value exceeds limit.
        return 1.0 - np.max([0.0, value - limit]) * 10

    def lower_bound_score(value, limit):
        # Penalise only the amount by which value falls short of limit.
        return 1.0 - np.abs(np.min([0.0, value - limit])) * 10

    if context_type in ("slower", "speed below"):
        return upper_bound_score(avg_force, context_target)
    if context_type in ("faster", "speed above"):
        return lower_bound_score(avg_force, context_target)
    if context_type == "x-axis slower":
        return upper_bound_score(avg_force_act[0], context_target)
    if context_type == "x-axis faster":
        return lower_bound_score(avg_force_act[0], context_target)
    if context_type == "y-axis slower":
        return upper_bound_score(avg_force_act[1], context_target)
    if context_type == "y-axis faster":
        return lower_bound_score(avg_force_act[1], context_target)
    if context_type == "speed above and below":
        return 0.5 * lower_bound_score(avg_force, context_target[0]) + 0.5 * (
            upper_bound_score(avg_force, context_target[1])
        )
    raise NotImplementedError(f"Unsupported context type: {context_type!r}")
